import pandas as pd
import numpy as np
import pdb
import matplotlib.pyplot as plt

# 设置字体
plt.rcParams["font.sans-serif"] = "SimHei"
# 默认可以显示负号，增加字体显示后。需对负号正常显示进行设置
plt.rcParams['axes.unicode_minus'] = False

# 打开动画开关
plt.ion()

data = pd.read_csv("../ML_data/company.csv",
                   encoding='ANSI')

# 将超市客户分3组：普通用户、vip、SVIP

# 1. 筛选特征值【不是所有的特征都 有助于 结果分组】
train_X = data[["平均每次消费金额", "平均消费周期（天）"]]


def std_sca(val):
    mean_val = val.mean()
    std_val = val.std()
    return (val - mean_val) / std_val


# 每个数据各自进行标准化
train_X["平均每次消费金额"] = train_X[["平均每次消费金额"]].transform(std_sca)
train_X["平均消费周期（天）"] = train_X["平均消费周期（天）"].transform(std_sca)

print(train_X)


def k_means(center):
    # print(train_X)
    # train_X.values 数据部分--是二维数组；遍历即获得每行的数据
    label = []
    for sample in train_X[["平均每次消费金额", "平均消费周期（天）"]].values:
        # print(sample)
        # pdb.set_trace()
        dis = np.sqrt(((sample - center) ** 2).sum(axis=1))

        # 第一个样本的：[304.13319451 478.73165761 267.18720029]
        # 表明 第一个样本 属于类2
        # 找最小值所在的索引
        # print("距离", dis, "类别", dis.argmin())
        label.append(dis.argmin())
        # break

    train_X["组号"] = label
    # print("第一次聚类后\n", train_X)

    new_center = train_X.groupby(by="组号").mean()

    return new_center.values


# 随机初始化聚类中心
center = np.array([[-0.544520, 3.211470],
                   [3.337761, 2.180376],
                   [0.483499, -0.240453]])


def show_result(train_X, new_center, timer):
    # 清除画布
    plt.cla()
    plt.title("第{}次的结果".format(timer))

    plt.scatter(
        train_X["平均每次消费金额"],
        train_X["平均消费周期（天）"],
        c=train_X["组号"]
    )

    # 绘制聚类中心
    plt.scatter(
        new_center[:, 0],
        new_center[:, 1],
        marker='*',
        s=200,
        # c=[0, 1, 2]  # 3种颜色，颜色使用默认颜色
        c=[3, 4, 5]  # 3种颜色，颜色使用默认颜色
    )
    plt.pause(3)


# 第一次可视化没有组号；初始化组号
train_X["组号"] = 0

timer = 0  # 记录次数
while True:

    show_result(train_X, center, timer)
    timer += 1
    new_center = k_means(center)
    # pdb.set_trace()
    if np.all(center == new_center):
        break

    else:
        # 如果新旧聚类中心不一致；
        # 新的聚类中心  赋值  为旧的聚类中心
        center = new_center

print("聚类一共进行了", timer, "次")
print("聚类结果\n", train_X)
print("聚类中心", new_center)

# show_result(train_X, new_center)

# 1. 筛选特征
# 2. 异常值【了解每列正常值的范围】
# 3. 数据标准化处理
# 4. 聚类算法【5-10组】
# 5. 观察每组特点：关注每组聚类中心【根据数值的意义，越大越小 越小越小--每组得分】

plt.ioff()  # 关闭动画
plt.show()  # 统一显示


data["组号"] = train_X["组号"]
print("原始数据\n", data)
print("聚类中心:\n", data.groupby(by="组号")[["平均每次消费金额","平均消费周期（天）"]].mean())

"""
       平均每次消费金额  平均消费周期（天）
组号                       
0   187.500000  78.500000   普通用户 【散客】
1   789.000000  35.000000   VIP     【普通用户】
2   213.210526  11.157895   SVIP    【VIP】
"""